(**
 * @copyright (C) 2021 SML# Development Team.
 * @author Atsushi Ohori
 * @author Liu Bochao
 * @author YAMATODANI Kiyoshi
 * @version $Id: TypeInferenceContext.sml,v 1.32 2008/05/31 12:18:23 ohori Exp $
 *)
structure TypeInferenceContext =
struct
  local
    structure T = Types
    structure IC = IDCalc
    structure TU = TypesBasics
    structure TC = TypedCalc
    structure FE = SMLFormat.FormatExpression

    (* Local alias for TC.idstatus.  It is declared in the local part of the
     * structure, so it is only visible to the declarations that follow.
     * The SMLFormat annotations below bind the formatter for TC.idstatus to
     * TC.formatWithType_idstatus.
     *)
    (*% 
       @formatter(TC.idstatus) TC.formatWithType_idstatus
     *)
    type idstatus 
     = (*%
          @format(x) x
        *)
       TC.idstatus

  in

  (* Type variable environment: a TvarMap whose entries are T.ty values.
   * The SMLFormat annotations below format the map with
   * IDCalc.formatEnclosedTvarMap and each entry with T.format_ty.
   *)
  (*%
   * @formatter(TvarMap.map) IDCalc.formatEnclosedTvarMap
   * @formatter(T.ty) T.format_ty
   *)
  type tvarEnv =
      (*%
         @format(ty map)
         map(ty)("{", ":", ",", "}")
       *)
      T.ty TvarMap.map

  (* Formats a list of elements, placing a separator before every element
   * except the first.
   *   formatter : maps one element to a list of format expressions.
   *   sep       : list of format expressions used as the separator.
   * Returns nil for the empty list.  Otherwise returns one FE.Sequence per
   * element; for every element after the first, the sequence starts with
   * the separator wrapped in its own FE.Sequence.
   *)
  fun formatList (formatter, sep) items =
      case items of
        nil => nil
      | first :: rest =>
        let
          (* A non-leading element: separator first, then the element. *)
          fun separated item = FE.Sequence (FE.Sequence sep :: formatter item)
        in
          FE.Sequence (formatter first) :: map separated rest
        end

  (* Formatter for VarMap values (referenced from the formatter annotation
   * of varEnv).
   *   formatter : formats one stored item.
   *   mapsto    : format expressions placed between a key and its item.
   *   comma     : format expressions placed between consecutive bindings.
   * Each binding is rendered as the key (via IDCalc.format_varInfo),
   * then mapsto, then the formatted item.  Bindings are taken in the order
   * returned by VarMap.listItemsi and joined with formatList.
   *)
  fun formatEnclosedVarMap (formatter, mapsto, comma) varMap =
      let
        fun formatBinding (var, item) =
            FE.Sequence (IDCalc.format_varInfo var)
            :: FE.Sequence mapsto
            :: formatter item
      in
        formatList (formatBinding, comma) (VarMap.listItemsi varMap)
      end

  (* Formatter for OPrimMap values (referenced from the formatter annotation
   * of oprimEnv).
   *   formatter : formats one stored item.
   *   mapsto    : format expressions placed between a key and its item.
   *   comma     : format expressions placed between consecutive bindings.
   * Each binding is rendered as the key (via IDCalc.format_oprimInfo),
   * then mapsto, then the formatted item.  Bindings are taken in the order
   * returned by OPrimMap.listItemsi and joined with formatList.
   *)
  fun formatEnclosedOPrimMap (formatter, mapsto, comma) oprimMap =
      let
        fun formatBinding (oprimInfo, item) =
            FE.Sequence (IDCalc.format_oprimInfo oprimInfo)
            :: FE.Sequence mapsto
            :: formatter item
      in
        formatList (formatBinding, comma) (OPrimMap.listItemsi oprimMap)
      end

  (* Variable environment: a VarMap whose entries are idstatus values
   * (idstatus is the local alias of TC.idstatus).
   * The SMLFormat annotation below formats the map with
   * formatEnclosedVarMap, defined earlier in this structure.
   *)
  (*%
     @formatter(VarMap.map) formatEnclosedVarMap
   *)
  type varEnv =
      (*%
         @format(var map) map(var)(":",+2)
       *)
      idstatus VarMap.map

  (* Overloaded-primitive environment: an OPrimMap whose entries are
   * T.oprimInfo values.
   * The SMLFormat annotations below format the map with
   * formatEnclosedOPrimMap, defined earlier in this structure, and each
   * entry with T.format_oprimInfo.
   *)
  (*%
     @formatter(OPrimMap.map) formatEnclosedOPrimMap
     @formatter(T.oprimInfo) T.format_oprimInfo
   *)
  type oprimEnv =
      (*%
       * @format(oprim map) map(oprim)(":",+2)
       *)
      T.oprimInfo OPrimMap.map

  (* The type inference context: a record bundling the three environments
   * declared above.
   *   tvarEnv  : type variable environment
   *   varEnv   : variable environment
   *   oprimEnv : overloaded-primitive environment
   * The SMLFormat annotation below prints the record as a brace-enclosed
   * block with one labelled entry per environment.
   *)
  (*%
   *)
  type context
    = (*%
         @format({tvarEnv, varEnv, oprimEnv})
          "{"
           1[
             +1
              1["tvarEnv:"+1 tvarEnv]
             +1
              1["varEnv:"+1 varEnv]
             +1
              1["oprimEnv:"+1 oprimEnv]
            ]
           +1
           "}"
      *)
    {
     tvarEnv: tvarEnv,
     varEnv: varEnv,
     oprimEnv: oprimEnv
    }

  (* The context in which all three environments are empty maps. *)
  val emptyContext : context =
      {
       oprimEnv = OPrimMap.empty,
       tvarEnv = TvarMap.empty,
       varEnv = VarMap.empty
      }

  (* Binds a variable to an idstatus in the variable environment.
   *   lambdaDepth : depth passed on to TU.adjustDepthInTy.
   *   context     : the context to extend.
   *   var         : the variable used as the key.
   *   idstatus    : the status recorded for var.
   * Before extending, TU.adjustDepthInTy is applied (with a fresh
   * `ref false` cell and lambdaDepth) to the type carried by idstatus, and
   * its result is discarded.
   * NOTE(review): presumably this call updates depth information inside
   * the type in place -- confirm against TypesBasics.
   * Returns a context whose varEnv has var inserted; tvarEnv and oprimEnv
   * are passed through unchanged.
   *)
  fun bindVar (lambdaDepth:int,
               ({oprimEnv,tvarEnv,varEnv} : context),
               (var:IDCalc.varInfo),
               idstatus:idstatus) =
      let
        (* Extracts the type carried by an idstatus. *)
        fun tyOf (TC.VARID {ty, ...}) = ty
          | tyOf (TC.RECFUNID ({ty, ...}, _)) = ty
      in
        TU.adjustDepthInTy (ref false) lambdaDepth (tyOf idstatus);
        {
         tvarEnv = tvarEnv,
         varEnv = VarMap.insert (varEnv, var, idstatus),
         oprimEnv = oprimEnv
        } : context
      end

  (* Binds an overloaded primitive in the context.
   *   context      : the context to extend.
   *   oprimInfoKey : the IDCalc-level oprimInfo used as the key.
   *   oprimInfo    : the Types-level oprimInfo stored under that key.
   * Returns a context whose oprimEnv has the binding inserted; tvarEnv and
   * varEnv are passed through unchanged.
   *)
  fun bindOPrim ({tvarEnv,varEnv,oprimEnv} : context,
                 oprimInfoKey:IC.oprimInfo,
                 oprimInfo:T.oprimInfo) =
      let
        val extendedOPrimEnv =
            OPrimMap.insert (oprimEnv, oprimInfoKey, oprimInfo)
      in
        {
         oprimEnv = extendedOPrimEnv,
         tvarEnv = tvarEnv,
         varEnv = varEnv
        } : context
      end
      
  (* Extends the variable environment of a context with another variable
   * environment.
   * The two maps are merged with VarMap.unionWith.  The combining function
   * always selects the second component of the pair it receives, so under
   * the usual unionWith argument order a binding in newVarEnv takes
   * precedence over one in the context for the same key.
   * tvarEnv and oprimEnv are passed through unchanged.
   *)
  fun extendContextWithVarEnv
        ({oprimEnv, tvarEnv, varEnv} : context, newVarEnv :varEnv) = 
      let
        fun preferSecond (_, second) = second
      in
        {
         oprimEnv = oprimEnv,
         tvarEnv = tvarEnv,
         varEnv = VarMap.unionWith preferSecond (varEnv, newVarEnv)
        } : context
      end

(*
  fun extendContextWithTvarEnv
        ({oprimEnv,tvarEnv,varEnv}: context, newTvarEnv:tvarEnv) =
    {
     tvarEnv = SEnv.unionWith #1 (newTvarEnv, tvarEnv),
     varEnv = varEnv,
     oprimEnv = oprimEnv
    } : context
*)

  (* Merges two contexts component by component.
   * Each of the three environments is merged with the unionWith of its map
   * structure.  The combining function always selects the second component
   * of the pair it receives, so under the usual unionWith argument order a
   * binding from the second context takes precedence over one from the
   * first for the same key.
   *)
  fun extendContextWithContext
        ({oprimEnv = baseOPrimEnv,
          tvarEnv = baseTvarEnv,
          varEnv = baseVarEnv} : context,
         {oprimEnv = addedOPrimEnv,
          tvarEnv = addedTvarEnv,
          varEnv = addedVarEnv} : context) =
      let
        fun preferSecond (_, second) = second
      in
        {
         oprimEnv = OPrimMap.unionWith preferSecond (baseOPrimEnv, addedOPrimEnv),
         tvarEnv = TvarMap.unionWith preferSecond (baseTvarEnv, addedTvarEnv),
         varEnv = VarMap.unionWith preferSecond (baseVarEnv, addedVarEnv)
        } : context
      end

  (* Replaces the type variable environment of a context.
   * The context's existing tvarEnv is discarded, not merged; varEnv and
   * oprimEnv are passed through unchanged.
   *)
  fun overrideContextWithTvarEnv
        ({oprimEnv, tvarEnv = _, varEnv} : context,
         newTvarEnv : tvarEnv) =
      {
       tvarEnv = newTvarEnv,
       varEnv = varEnv,
       oprimEnv = oprimEnv
      } : context

(*
  fun lookupTvarInContext
        ({oprimEnv, tvarEnv,varEnv} : context, string) 
      : T.ty option =
      case SEnv.find(tvarEnv, string) of
        SOME tvStateRef => SOME(T.TYVARty tvStateRef)
      | NONE => NONE
*)
  (* Adds user type variables to the type variable environment.
   *   lambdaDepth    : depth handed to T.newUtvar for every new variable.
   *   context        : the context to extend.
   *   kindedTvarList : the user type variables, each paired with its kind.
   *   loc            : accepted but not used by this function.
   * The list is processed from left to right.  For each utvar a state
   * reference is created with T.newUtvar (lambdaDepth, utvar); the utvar is
   * then inserted into tvarEnv bound to T.TYVARty of that reference.
   * Returns the extended context, together with a map (started from
   * TvarMap.empty) from each processed utvar to its state reference and
   * kind.  varEnv and oprimEnv are passed through unchanged.
   *)
  fun addUtvar 
        (lambdaDepth:int)
        ({oprimEnv,tvarEnv,varEnv}:context)
        (kindedTvarList:(T.utvar * IDCalc.tvarKind) list) 
        (loc:Loc.loc) 
       : context * (Types.tvState ref * IDCalc.tvarKind) TvarMap.map =
      let
        (* One fold step: allocate the state ref, then record the utvar in
         * both accumulated maps. *)
        fun register ((utvar, tvarKind), (envAcc, addedAcc)) =
            let
              val tvStateRef = T.newUtvar (lambdaDepth, utvar)
            in
              (TvarMap.insert (envAcc, utvar, T.TYVARty tvStateRef),
               TvarMap.insert (addedAcc, utvar, (tvStateRef, tvarKind)))
            end
        val (extendedTvarEnv, addedUtvars) =
            foldl register (tvarEnv, TvarMap.empty) kindedTvarList
      in
        ({
          oprimEnv = oprimEnv,
          tvarEnv = extendedTvarEnv,
          varEnv = varEnv
         },
         addedUtvars)
      end

end
end
